import pandas as pd
import numpy as np
from keras.utils.np_utils import to_categorical
import keras
import ctypes
# Integer id of each supported task; batch_generator forwards this id to the
# native library's init().
task_dict = {
    "discard": 0,
    "pon": 1,
    "chi": 2,
    "riichi": 3,
}

# Size of the second axis of the X buffer that batch_generator allocates
# for each task.
X_task_dict = {
    "discard": 59,
    "pon": 60,
    "chi": 60,
    "riichi": 60,
}

# Number of classes used when batch_generator one-hot encodes the labels
# for each task.
y_task_dict = {
    "discard": 34,
    "pon": 2,
    "chi": 4,
    "riichi": 2,
}
class batch_generator(keras.utils.Sequence):
    """Keras ``Sequence`` that fetches batches from a native shared library.

    Data access is delegated to the library loaded through ``ctypes``
    (its ``init``, ``get_item``, ``get_len`` and ``shuffle_index`` entry
    points). This class only allocates the output buffers and one-hot
    encodes the labels.
    """

    def __init__(self, index_file_name, csv_file_name, batch_size = 256, task = 'discard', data_arg_on = False, libname = '../C++/preprocess.so'):
        """Load the native library and initialise it for one task.

        index_file_name: path of the index file (``str`` or ``bytes``)
        csv_file_name: path of the csv file (``str`` or ``bytes``)
        batch_size: number of samples per batch
        task: 'discard'/'pon'/'chi'/'riichi'; any other value raises KeyError
        data_arg_on: flag forwarded unchanged to the native ``init``
            (presumably toggles data augmentation -- confirm against the
            library source)
        libname: path of the shared library to load
        """
        super(batch_generator, self).__init__()
        self.batch_size = batch_size
        self.task = task_dict[task]
        self.X_plane = X_task_dict[task]
        self.y_class = y_task_dict[task]
        self.lib = ctypes.cdll.LoadLibrary(libname)
        self.data_arg_on = data_arg_on
        self.lib.init(
            self._as_c_path(index_file_name),
            self._as_c_path(csv_file_name),
            self.batch_size,
            self.task,
            self.data_arg_on,
        )

    @staticmethod
    def _as_c_path(path):
        """Return *path* as ``bytes`` so ctypes passes it as ``char *``.

        ctypes marshals ``bytes`` as ``char *`` but text strings as
        ``wchar_t *``; handing a text path to a function that reads
        ``char *`` yields a truncated name.
        NOTE(review): assumes the native ``init`` takes ``char *`` paths --
        confirm against the library source.
        """
        if isinstance(path, bytes):
            return path
        return path.encode('utf-8')

    def __getitem__(self, index):
        """Gets batch at position `index`.

        # Arguments
            index: position of the batch in the Sequence.

        # Returns
            A tuple ``(X, y)``: ``X`` is a float32 array of shape
            ``(batch_size, X_plane, 34, 4)`` and ``y`` is the one-hot
            encoding (``y_class`` classes) of the int32 labels.
        """
        X = np.zeros((self.batch_size, self.X_plane, 34, 4), dtype = np.float32)
        y = np.zeros((self.batch_size), dtype = np.int32)

        # Raw pointers to the array storage. X and y stay referenced for the
        # whole call, so the buffers cannot be freed underneath the library.
        ptr_X = X.ctypes.data_as(ctypes.c_void_p)
        ptr_y = y.ctypes.data_as(ctypes.c_void_p)

        # The third argument is the offset of this batch's first sample;
        # the library is expected to fill both buffers in place.
        self.lib.get_item(ptr_X, ptr_y, self.batch_size * index)
        y = to_categorical(y, self.y_class)
        return X, y

    def __len__(self):
        """Number of batch in the Sequence.

        # Returns
            The value reported by the native library's ``get_len``.
        """
        return self.lib.get_len()

    def on_epoch_end(self):
        """Method called at the end of every epoch.

        Delegates to the native library's ``shuffle_index``.
        """
        self.lib.shuffle_index()
